Federated Learning: Create Multi-Armed Bandit


In [1]:
%load_ext autoreload
%autoreload 2

import syft as sy
from syft.serde import protobuf
from syft_proto.execution.v1.plan_pb2 import Plan as PlanPB
from syft_proto.execution.v1.state_pb2 import State as StatePB
from syft.grid.clients.static_fl_client import StaticFLClient
from syft.execution.state import State
from syft.execution.placeholder import PlaceHolder
from syft.execution.translation import TranslationTarget

import torch as th
from torch import nn

import os
import websockets
import json
import requests

sy.make_hook(globals())
hook.local_worker.framework = None
th.random.manual_seed(1)


Setting up Sandbox...
Done!
Out[1]:
<torch._C.Generator at 0x106d6d170>

In [2]:
import scipy.stats as ss
from IPython.display import clear_output
from IPython import display
%matplotlib inline
import matplotlib.pyplot as plt
import numpy as np
from scipy.stats import bernoulli
import random

num_possibilities = 24

def draw_plot(_alpha_beta, time_step=8888):
    fig = plt.figure()
    fig.canvas.draw()
    plt.xlim(0,1)
    x = np.linspace(0, 1, 100)
    tempy = []

    for idx in range(len(_alpha_beta[0])):
        a = _alpha_beta[0][idx]
        b = _alpha_beta[1][idx]
        y = ss.beta.pdf(x, a,b)
        plt.plot(x,y,label=idx)

    plt.ylim(0,30)
    plt.legend()
    plt.title('@ time_step: {}'.format(time_step))
    display.display(plt.gcf())

class simulator():
    def __init__(self, slots=[.1, .6, .8]):
        self.slots = slots
        self.action_space = list(range(len(slots)))

    def simulate(self, slot_idx):
        return bernoulli.rvs(self.slots[slot_idx]) 

    def simulate_ui(self, slot_idx):
        rwd = input(f'displaying with UI config {slot_idx} out of {len(self.slots)} options, input 1 for click 0 for no')
        return rwd

rand_rates = [min(random.random(), .8) - (random.random() / 100) for iter in range(num_possibilities)]

print('rand_rates', rand_rates)

env = simulator(rand_rates)
env.action_space

def run_simulation(n):
    _one_vec= [1.0] * num_possibilities
    _blank_vec= [0.0] * num_possibilities
    alphas = th.tensor(_one_vec, requires_grad=False)
    betas = th.tensor(_one_vec, requires_grad=False)

    samples_from_beta_distr = {}
    time_step = 0

    for x in range(n):
        print(x)
        rwd_vec = _blank_vec[:]
        sampled_vec = _blank_vec[:]
        for k in range(num_possibilities):#env.action_space:
            samples_from_beta_distr[k] = np.random.beta(alphas[k], betas[k])

        print('samples_from_beta_distr', samples_from_beta_distr)

        selected_action = max(samples_from_beta_distr, key=samples_from_beta_distr.get)
        reward = env.simulate(selected_action)
        time_step += 1

        print('selected action: ', selected_action, 'rwd: ',  reward)

        rwd_vec[selected_action] = float(reward)
        sampled_vec[selected_action] = 1

        print('updated rewd vec: ', rwd_vec)

        (alphas, betas) = bandit_thompson(th.tensor(rwd_vec),th.tensor(sampled_vec), alphas, betas)

        print('time_step: ', x, 'new params: ', alphas, betas)

        if x % 5 == 0:
            draw_plot((alphas, betas), x)

    return (alphas, betas)

''' thompson sampling bandit '''
# the first elem of alpha_beta is the alpha parameter for a beta distr for i-th option where i = index
# the 2nd elem of alpha_beta is the beta parameter ... 
# alphabeta will never be more than 2 elements but each alpha and beta vector could have more elements if we have more options
# this organization allows us to do vectorized updating of the params

_blank_vec= [0.0] * num_possibilities
_one_vec = [1.0] * num_possibilities

print(_blank_vec, len(_blank_vec))

alphas = th.tensor(_one_vec, requires_grad=False)
betas = th.tensor(_one_vec, requires_grad=False)

rwd = th.tensor(_blank_vec)
samples = th.tensor(_blank_vec)
bandit_args_th = [rwd, samples, alphas, betas]
bandit_th_args_shape = [rwd.shape, samples.shape, alphas.shape, betas.shape]

@sy.func2plan(args_shape=bandit_th_args_shape)
def bandit_thompson(reward, sample_vector, alphas, betas):
    prev_alpha = alphas
    prev_beta = betas

    alphas = prev_alpha.add(reward)
    betas = prev_beta.add(sample_vector.sub(reward))

    return (alphas, betas)
        
final_alphas, final_betas = run_simulation(20)


rand_rates [0.297358606175418, 0.05603661127506938, 0.6535666667862454, 0.7976295943783823, 0.40277829020538214, 0.5026309217184822, 0.13330716707173468, 0.7912418987726315, 0.05130823808661064, 0.7988684286182635, 0.7233578793010739, 0.13634823264616455, 0.2489012610330408, 0.6557618490434519, 0.005809536754381204, 0.6308226260013383, 0.44080938576122314, 0.3495717055632467, 0.7409902795102667, 0.6830623924209595, 0.636013220359687, 0.2640373207941863, 0.21449798032700132, 0.7196567097740026]
[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0] 24
0
samples_from_beta_distr {0: 0.5506979372068089, 1: 0.12734818511988877, 2: 0.3412299350587927, 3: 0.8054136799715076, 4: 0.11273495998826488, 5: 0.2203243263062467, 6: 0.8511804800696697, 7: 0.8272297145067081, 8: 0.7347109102716323, 9: 0.16416247817217394, 10: 0.9116400979973631, 11: 0.552922863768152, 12: 0.409619694092394, 13: 0.053711362851815186, 14: 0.4102858637271981, 15: 0.1755709112810346, 16: 0.28668256839336653, 17: 0.7524991798902002, 18: 0.15638440379315077, 19: 0.6385798660745896, 20: 0.28964771927254906, 21: 0.4114826385191182, 22: 0.2293564552043075, 23: 0.5641615291770792}
selected action:  10 rwd:  1
updated rewd vec:  [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]
time_step:  0 new params:  tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 2., 1., 1., 1., 1., 1., 1., 1.,
        1., 1., 1., 1., 1., 1.]) tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
        1., 1., 1., 1., 1., 1.])
1
samples_from_beta_distr {0: 0.05526973925225801, 1: 0.9363731604539063, 2: 0.465605153011164, 3: 0.12581031566530032, 4: 0.6687073970492825, 5: 0.0688246970269409, 6: 0.06940331839998404, 7: 0.9944975517768292, 8: 0.5281367997412384, 9: 0.27433989397079594, 10: 0.962324819881041, 11: 0.84290418547375, 12: 0.9244810310422594, 13: 0.5363969521889174, 14: 0.26385278124488276, 15: 0.10954437908648133, 16: 0.32328221248938677, 17: 0.9945727959397604, 18: 0.426383213935463, 19: 0.020936664058784523, 20: 0.4267006237946333, 21: 0.5151407316325075, 22: 0.6693690965319055, 23: 0.7915986862793059}
selected action:  17 rwd:  0
updated rewd vec:  [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]
time_step:  1 new params:  tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 2., 1., 1., 1., 1., 1., 1., 1.,
        1., 1., 1., 1., 1., 1.]) tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 2.,
        1., 1., 1., 1., 1., 1.])
2
samples_from_beta_distr {0: 0.7149540620152979, 1: 0.6274469116797207, 2: 0.2689362121261097, 3: 0.40422501062140864, 4: 0.8028341777614662, 5: 0.019586249997208913, 6: 0.7984343976109964, 7: 0.11227888847772868, 8: 0.11025359222209415, 9: 0.5010775561089528, 10: 0.26781037838055227, 11: 0.3931497465966921, 12: 0.6840780651105326, 13: 0.9633830187751966, 14: 0.11545520670841547, 15: 0.46031173297852973, 16: 0.23529591471963948, 17: 0.8789725805675713, 18: 0.7654694316983518, 19: 0.7707328151974192, 20: 0.13500166038775147, 21: 0.30859018881347205, 22: 0.22780531084144628, 23: 0.30704381125170516}
selected action:  13 rwd:  1
updated rewd vec:  [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]
time_step:  2 new params:  tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 2., 1., 1., 2., 1., 1., 1., 1.,
        1., 1., 1., 1., 1., 1.]) tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 2.,
        1., 1., 1., 1., 1., 1.])
3
samples_from_beta_distr {0: 0.5008524795574897, 1: 0.7056962141022319, 2: 0.367156939182533, 3: 0.727779444686781, 4: 0.5244261681444119, 5: 0.18500644521757653, 6: 0.8666370096647015, 7: 0.5535315128480677, 8: 0.4921355918632565, 9: 0.21494070121617023, 10: 0.3921774398313786, 11: 0.1269807225861865, 12: 0.9238381007513979, 13: 0.9728075801055874, 14: 0.8745181453032552, 15: 0.3932236383790028, 16: 0.6315723870229063, 17: 0.06665402990532296, 18: 0.1286950326858251, 19: 0.7885621117345633, 20: 0.8984930148214297, 21: 0.9015410022676688, 22: 0.301042247559214, 23: 0.3861131899341755}
selected action:  13 rwd:  1
updated rewd vec:  [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]
time_step:  3 new params:  tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 2., 1., 1., 3., 1., 1., 1., 1.,
        1., 1., 1., 1., 1., 1.]) tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 2.,
        1., 1., 1., 1., 1., 1.])
4
samples_from_beta_distr {0: 0.13848779091904423, 1: 0.5197625067781384, 2: 0.14262615127646366, 3: 0.2797798101320302, 4: 0.3298988593741842, 5: 0.23620519517908165, 6: 0.3333982790875833, 7: 0.041444508498744426, 8: 0.6867486548154904, 9: 0.3327732536677623, 10: 0.8468891275215445, 11: 0.11963267574245129, 12: 0.8023926672962596, 13: 0.39619467020285454, 14: 0.5875112104068217, 15: 0.6832789198157829, 16: 0.9300713918896626, 17: 0.5004319145006589, 18: 0.4879605692481827, 19: 0.906440366274931, 20: 0.9424863804079436, 21: 0.44619580204215087, 22: 0.871041277591744, 23: 0.7736776928920787}
selected action:  20 rwd:  1
updated rewd vec:  [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0]
time_step:  4 new params:  tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 2., 1., 1., 3., 1., 1., 1., 1.,
        1., 1., 2., 1., 1., 1.]) tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 2.,
        1., 1., 1., 1., 1., 1.])
5
samples_from_beta_distr {0: 0.47314638863837677, 1: 0.07978892984938589, 2: 0.222324699368296, 3: 0.16034178529127938, 4: 0.2980515350590774, 5: 0.13869773473203273, 6: 0.7330845778452255, 7: 0.7993029673258709, 8: 0.5420167940883778, 9: 0.6411556318567913, 10: 0.2134087941262705, 11: 0.08093793094311227, 12: 0.24148391524278792, 13: 0.8979076395623248, 14: 0.7291708166829659, 15: 0.26523576745837457, 16: 0.3647285273983646, 17: 0.10161606127125439, 18: 0.8858706886258826, 19: 0.23656131503068595, 20: 0.8982079860377098, 21: 0.8634390929524525, 22: 0.8229970633818392, 23: 0.19140454648680622}
selected action:  20 rwd:  1
updated rewd vec:  [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0]
time_step:  5 new params:  tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 2., 1., 1., 3., 1., 1., 1., 1.,
        1., 1., 3., 1., 1., 1.]) tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 2.,
        1., 1., 1., 1., 1., 1.])
6
samples_from_beta_distr {0: 0.1401428801485443, 1: 0.974147463188781, 2: 0.30376828455318594, 3: 0.2339808208046095, 4: 0.28862818964283304, 5: 0.019717153621291665, 6: 0.037667397377785, 7: 0.5805433607444153, 8: 0.7896349439444492, 9: 0.3452625939645493, 10: 0.2723375008915747, 11: 0.9680199747969336, 12: 0.5261443242933898, 13: 0.7532022076115457, 14: 0.6971128282870356, 15: 0.858778497982601, 16: 0.43286767173627794, 17: 0.1683095911489608, 18: 0.1225194324674314, 19: 0.567170548353685, 20: 0.6473972014142961, 21: 0.4620200595430399, 22: 0.09413591565081764, 23: 0.9192375290713243}
selected action:  1 rwd:  0
updated rewd vec:  [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]
time_step:  6 new params:  tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 2., 1., 1., 3., 1., 1., 1., 1.,
        1., 1., 3., 1., 1., 1.]) tensor([1., 2., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 2.,
        1., 1., 1., 1., 1., 1.])
7
samples_from_beta_distr {0: 0.9230257456030304, 1: 0.3645578100627148, 2: 0.035751801143866904, 3: 0.5163825333016683, 4: 0.7295101142677434, 5: 0.3323377137166891, 6: 0.22768698061508033, 7: 0.48412559077573286, 8: 0.906553717357562, 9: 0.4535226821169463, 10: 0.6892427607513638, 11: 0.028546127172355832, 12: 0.888217361234233, 13: 0.6172741733879068, 14: 0.4738412264028258, 15: 0.40325916949241236, 16: 0.5898762939432933, 17: 0.3662555138448933, 18: 0.7067267481997584, 19: 0.3364167243371776, 20: 0.640031181286888, 21: 0.8474463483291111, 22: 0.7016873869752583, 23: 0.07036476962761358}
selected action:  0 rwd:  0
updated rewd vec:  [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]
time_step:  7 new params:  tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 2., 1., 1., 3., 1., 1., 1., 1.,
        1., 1., 3., 1., 1., 1.]) tensor([2., 2., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 2.,
        1., 1., 1., 1., 1., 1.])
8
samples_from_beta_distr {0: 0.05200012205782353, 1: 0.2904588505686596, 2: 0.14680201010968635, 3: 0.8441923260930786, 4: 0.30877806506706446, 5: 0.42528936085289665, 6: 0.9358085204635976, 7: 0.9617639845921127, 8: 0.8611814057280913, 9: 0.6715686031289314, 10: 0.7641966204317939, 11: 0.33903901440983464, 12: 0.455141863329618, 13: 0.5042246211659381, 14: 0.26735158605238235, 15: 0.9556703034251757, 16: 0.5887054896198906, 17: 0.0016204648865946931, 18: 0.7264641298394964, 19: 0.886252488481099, 20: 0.4125354337797974, 21: 0.2974871470344012, 22: 0.9578466944377555, 23: 0.35280515758000336}
selected action:  7 rwd:  1
updated rewd vec:  [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]
time_step:  8 new params:  tensor([1., 1., 1., 1., 1., 1., 1., 2., 1., 1., 2., 1., 1., 3., 1., 1., 1., 1.,
        1., 1., 3., 1., 1., 1.]) tensor([2., 2., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 2.,
        1., 1., 1., 1., 1., 1.])
9
samples_from_beta_distr {0: 0.7419689093889408, 1: 0.15529480720359373, 2: 0.8221530452311481, 3: 0.7563261261857206, 4: 0.04681377398404287, 5: 0.7941902701191741, 6: 0.1015800840214335, 7: 0.7844727565099932, 8: 0.8953709476303778, 9: 0.4985647507773594, 10: 0.7024491768582631, 11: 0.18423583009803784, 12: 0.762980410909308, 13: 0.5821922200666525, 14: 0.7501712554799144, 15: 0.759222862910083, 16: 0.6515800146895993, 17: 0.7942967034596281, 18: 0.26329567474526666, 19: 0.400140575641863, 20: 0.9783677030030322, 21: 0.7742750879447619, 22: 0.5709717136309612, 23: 0.05749662450975622}
selected action:  20 rwd:  1
updated rewd vec:  [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0]
time_step:  9 new params:  tensor([1., 1., 1., 1., 1., 1., 1., 2., 1., 1., 2., 1., 1., 3., 1., 1., 1., 1.,
        1., 1., 4., 1., 1., 1.]) tensor([2., 2., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 2.,
        1., 1., 1., 1., 1., 1.])
10
samples_from_beta_distr {0: 0.4139190172336117, 1: 0.22274936135828985, 2: 0.1663903134460372, 3: 0.5413620483595575, 4: 0.642383996219157, 5: 0.24312790497569967, 6: 0.14264179308101835, 7: 0.16688451143264402, 8: 0.058448372369845375, 9: 0.2484477375177364, 10: 0.8793901803686325, 11: 0.13838535388041365, 12: 0.5502816234576827, 13: 0.5495623686596361, 14: 0.440036717435472, 15: 0.6041343530773836, 16: 0.03778897886365288, 17: 0.525824678374354, 18: 0.4489100473324218, 19: 0.1869245667132453, 20: 0.471640822572125, 21: 0.5956641592073141, 22: 0.01699599557776346, 23: 0.8625867356778912}
selected action:  10 rwd:  0
updated rewd vec:  [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]
time_step:  10 new params:  tensor([1., 1., 1., 1., 1., 1., 1., 2., 1., 1., 2., 1., 1., 3., 1., 1., 1., 1.,
        1., 1., 4., 1., 1., 1.]) tensor([2., 2., 1., 1., 1., 1., 1., 1., 1., 1., 2., 1., 1., 1., 1., 1., 1., 2.,
        1., 1., 1., 1., 1., 1.])
11
samples_from_beta_distr {0: 0.023941136989173013, 1: 0.5633849781493063, 2: 0.035682132704092664, 3: 0.5266266067739089, 4: 0.43959997747319784, 5: 0.5769594742336499, 6: 0.6054956295410395, 7: 0.4537087763146745, 8: 0.2846378502312384, 9: 0.8032369940698173, 10: 0.31815064638565715, 11: 0.12367429677200369, 12: 0.7725832994780015, 13: 0.9547402557719097, 14: 0.20564510113499096, 15: 0.9655792271787055, 16: 0.4919510387616295, 17: 0.09756725237193034, 18: 0.19091708361884224, 19: 0.36388415747570235, 20: 0.8923353913026014, 21: 0.5736379297690785, 22: 0.3684560066072354, 23: 0.9905344317526267}
selected action:  23 rwd:  1
updated rewd vec:  [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0]
time_step:  11 new params:  tensor([1., 1., 1., 1., 1., 1., 1., 2., 1., 1., 2., 1., 1., 3., 1., 1., 1., 1.,
        1., 1., 4., 1., 1., 2.]) tensor([2., 2., 1., 1., 1., 1., 1., 1., 1., 1., 2., 1., 1., 1., 1., 1., 1., 2.,
        1., 1., 1., 1., 1., 1.])
12
samples_from_beta_distr {0: 0.21949834489386502, 1: 0.05628738265183388, 2: 0.84326397025065, 3: 0.9491711447110887, 4: 0.0006033228265391446, 5: 0.5162563427268354, 6: 0.1565325710456959, 7: 0.21007858631087178, 8: 0.6567595100939616, 9: 0.3384202759966859, 10: 0.3320112769724615, 11: 0.2666163433107313, 12: 0.41521519813437546, 13: 0.7881101022242507, 14: 0.4307960319657628, 15: 0.9989568162107866, 16: 0.3207213556347316, 17: 0.08999196461546352, 18: 0.9234737000316668, 19: 0.7846278916029781, 20: 0.9247840784674691, 21: 0.4954199373872133, 22: 0.8585124507503339, 23: 0.9619166634773394}
selected action:  15 rwd:  0
updated rewd vec:  [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]
time_step:  12 new params:  tensor([1., 1., 1., 1., 1., 1., 1., 2., 1., 1., 2., 1., 1., 3., 1., 1., 1., 1.,
        1., 1., 4., 1., 1., 2.]) tensor([2., 2., 1., 1., 1., 1., 1., 1., 1., 1., 2., 1., 1., 1., 1., 2., 1., 2.,
        1., 1., 1., 1., 1., 1.])
13
samples_from_beta_distr {0: 0.31489357195970086, 1: 0.09087247015210756, 2: 0.42286662961752935, 3: 0.7856546617818616, 4: 0.5708594557846264, 5: 0.04672789735504838, 6: 0.7005600759912332, 7: 0.6188625727719199, 8: 0.7685817651132113, 9: 0.6836479338095973, 10: 0.412547943176988, 11: 0.3673065988011884, 12: 0.1836447559962553, 13: 0.7385375744937361, 14: 0.25019083694667693, 15: 0.2044921946156866, 16: 0.5512867672148549, 17: 0.5650508998333585, 18: 0.9878987411301382, 19: 0.8607390807292926, 20: 0.908389130503321, 21: 0.3714578122894905, 22: 0.536392936493946, 23: 0.9986687386717062}
selected action:  23 rwd:  0
updated rewd vec:  [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]
time_step:  13 new params:  tensor([1., 1., 1., 1., 1., 1., 1., 2., 1., 1., 2., 1., 1., 3., 1., 1., 1., 1.,
        1., 1., 4., 1., 1., 2.]) tensor([2., 2., 1., 1., 1., 1., 1., 1., 1., 1., 2., 1., 1., 1., 1., 2., 1., 2.,
        1., 1., 1., 1., 1., 2.])
14
samples_from_beta_distr {0: 0.8069684115229181, 1: 0.8640798960523162, 2: 0.4826988881769434, 3: 0.7076996308162969, 4: 0.790638458727509, 5: 0.9838776972082116, 6: 0.5311155032768353, 7: 0.6856083532728171, 8: 0.48820624408933383, 9: 0.9916808304036145, 10: 0.50885454173307, 11: 0.27187030305305687, 12: 0.06926106106727559, 13: 0.8385561938099774, 14: 0.704268265170965, 15: 0.5130152862053922, 16: 0.38935566724581694, 17: 0.3970148520559469, 18: 0.1683855969211513, 19: 0.8484727240268563, 20: 0.6103418264007191, 21: 0.5517137783256966, 22: 0.49659040009620314, 23: 0.17496169627188582}
selected action:  9 rwd:  1
updated rewd vec:  [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]
time_step:  14 new params:  tensor([1., 1., 1., 1., 1., 1., 1., 2., 1., 2., 2., 1., 1., 3., 1., 1., 1., 1.,
        1., 1., 4., 1., 1., 2.]) tensor([2., 2., 1., 1., 1., 1., 1., 1., 1., 1., 2., 1., 1., 1., 1., 2., 1., 2.,
        1., 1., 1., 1., 1., 2.])
15
samples_from_beta_distr {0: 0.08441941473105481, 1: 0.45860929914294407, 2: 0.559234956433128, 3: 0.019415003376341767, 4: 0.8890203047951274, 5: 0.004462665258097771, 6: 0.24255476282249028, 7: 0.4657380065463783, 8: 0.5586388001828917, 9: 0.426874357091701, 10: 0.33474190226130485, 11: 0.6908202822207221, 12: 0.7551423074184013, 13: 0.9018793791696056, 14: 0.8070329139484025, 15: 0.7729865787203681, 16: 0.5816247956778041, 17: 0.31500938132212447, 18: 0.8696541912427456, 19: 0.2572668018630663, 20: 0.8362729671741547, 21: 0.7037300411106899, 22: 0.575568255380829, 23: 0.5924351219002928}
selected action:  13 rwd:  0
updated rewd vec:  [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]
time_step:  15 new params:  tensor([1., 1., 1., 1., 1., 1., 1., 2., 1., 2., 2., 1., 1., 3., 1., 1., 1., 1.,
        1., 1., 4., 1., 1., 2.]) tensor([2., 2., 1., 1., 1., 1., 1., 1., 1., 1., 2., 1., 1., 2., 1., 2., 1., 2.,
        1., 1., 1., 1., 1., 2.])
16
samples_from_beta_distr {0: 0.5168527888599397, 1: 0.7066230734825681, 2: 0.2013856887954505, 3: 0.1316701986290085, 4: 0.2700154973568112, 5: 0.7897202705213702, 6: 0.10965646739899967, 7: 0.7620403957740736, 8: 0.6613910742845466, 9: 0.25128792076220796, 10: 0.8723314979987878, 11: 0.1842910040522458, 12: 0.8494624572542318, 13: 0.5119488622102668, 14: 0.6101457772110034, 15: 0.0891821078549171, 16: 0.9103133945203339, 17: 0.17611432876625704, 18: 0.9194794921359841, 19: 0.09240626489424283, 20: 0.6824103891525054, 21: 0.1492838632243608, 22: 0.7847791282254496, 23: 0.3458127025035864}
selected action:  18 rwd:  1
updated rewd vec:  [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0]
time_step:  16 new params:  tensor([1., 1., 1., 1., 1., 1., 1., 2., 1., 2., 2., 1., 1., 3., 1., 1., 1., 1.,
        2., 1., 4., 1., 1., 2.]) tensor([2., 2., 1., 1., 1., 1., 1., 1., 1., 1., 2., 1., 1., 2., 1., 2., 1., 2.,
        1., 1., 1., 1., 1., 2.])
17
samples_from_beta_distr {0: 0.2840401203187039, 1: 0.842367090470767, 2: 0.4445953595271404, 3: 0.7988370624505824, 4: 0.7785231744563916, 5: 0.1974018619823125, 6: 0.2507665181906877, 7: 0.5819362459682696, 8: 0.8764602192108677, 9: 0.914346117421789, 10: 0.47145652172593117, 11: 0.14140837193196007, 12: 0.9951174402488363, 13: 0.7226898835714252, 14: 0.6832564897785053, 15: 0.11925981120562829, 16: 0.36687962462537943, 17: 0.2393232468100464, 18: 0.05975261700632668, 19: 0.6721089734064287, 20: 0.3152655681928277, 21: 0.7682512644142142, 22: 0.8933987320124293, 23: 0.27846036148655706}
selected action:  12 rwd:  0
updated rewd vec:  [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]
time_step:  17 new params:  tensor([1., 1., 1., 1., 1., 1., 1., 2., 1., 2., 2., 1., 1., 3., 1., 1., 1., 1.,
        2., 1., 4., 1., 1., 2.]) tensor([2., 2., 1., 1., 1., 1., 1., 1., 1., 1., 2., 1., 2., 2., 1., 2., 1., 2.,
        1., 1., 1., 1., 1., 2.])
18
samples_from_beta_distr {0: 0.3713856504549808, 1: 0.4355562535428058, 2: 0.4679950082756568, 3: 0.4156312639939726, 4: 0.20979387936834784, 5: 0.8385668199274243, 6: 0.022742235578157747, 7: 0.8461174371631349, 8: 0.3009444473957157, 9: 0.9128802789705209, 10: 0.5039236100745127, 11: 0.42238274488101946, 12: 0.2749461829802582, 13: 0.7885836745906996, 14: 0.14268527408861129, 15: 0.2951486574966772, 16: 0.035709661967266576, 17: 0.6344845140005503, 18: 0.13910494125285813, 19: 0.3655064818595665, 20: 0.995889269568735, 21: 0.05041477002298604, 22: 0.7376208472563899, 23: 0.46890558377229724}
selected action:  20 rwd:  1
updated rewd vec:  [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0]
time_step:  18 new params:  tensor([1., 1., 1., 1., 1., 1., 1., 2., 1., 2., 2., 1., 1., 3., 1., 1., 1., 1.,
        2., 1., 5., 1., 1., 2.]) tensor([2., 2., 1., 1., 1., 1., 1., 1., 1., 1., 2., 1., 2., 2., 1., 2., 1., 2.,
        1., 1., 1., 1., 1., 2.])
19
samples_from_beta_distr {0: 0.07036372611869618, 1: 0.44173404974969976, 2: 0.5607874165271313, 3: 0.6481677441701276, 4: 0.2156904412309767, 5: 0.876352086339783, 6: 0.3006043167386701, 7: 0.7370367997928543, 8: 0.5018840142997841, 9: 0.8529051793800777, 10: 0.651589866378827, 11: 0.352943137754137, 12: 0.27810011950183416, 13: 0.6771166004721451, 14: 0.915840094966665, 15: 0.3656478959024862, 16: 0.5039979041604895, 17: 0.06890094022333258, 18: 0.5590439664455854, 19: 0.3531175878448488, 20: 0.812039534097957, 21: 0.8615785837220523, 22: 0.6843142962177696, 23: 0.3026153926605436}
selected action:  14 rwd:  0
updated rewd vec:  [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]
time_step:  19 new params:  tensor([1., 1., 1., 1., 1., 1., 1., 2., 1., 2., 2., 1., 1., 3., 1., 1., 1., 1.,
        2., 1., 5., 1., 1., 2.]) tensor([2., 2., 1., 1., 1., 1., 1., 1., 1., 1., 2., 1., 2., 2., 2., 2., 1., 2.,
        1., 1., 1., 1., 1., 2.])

In [3]:
#sanity check
print(np.argmax(final_alphas), np.argmax(final_betas))
print(np.argmin(final_alphas), np.argmin(final_betas))

rand_rates[18]


tensor(20) tensor(0)
tensor(0) tensor(2)
Out[3]:
0.7409902795102667

Step 3: Define Averaging Plan

Averaging Plan is executed by PyGrid at the end of the cycle, to average diffs submitted by workers and update the model and create new checkpoint for the next cycle.

Diff is the difference between client-trained model params and original model params, so it has same number of tensors and tensor's shapes as the model parameters.

We define Plan that processes one diff at a time. Such Plans require iterative_plan flag set to True in server_config when hosting FL model to PyGrid.

Plan below will calculate simple mean of each parameter.


In [4]:
@sy.func2plan()
def avg_plan(avg, item, num):
    new_avg = []

    for i, param in enumerate(avg):
        new_avg.append((avg[i] * num + item[i]) / (num + 1))
        
    return new_avg

# Build the Plan
_ = avg_plan.build(bandit_args_th, bandit_args_th
, th.tensor([1.0]))

In [5]:
# Let's check Plan contents
print(avg_plan.code)


def avg_plan(arg_1, arg_2, arg_3, arg_4, arg_5, arg_6, arg_7, arg_8, arg_9):
    var_0 = arg_1.__mul__(arg_9)
    var_1 = var_0.__add__(arg_5)
    var_2 = arg_9.__add__(1)
    out_1 = var_1.__truediv__(var_2)
    var_3 = arg_2.__mul__(arg_9)
    var_4 = var_3.__add__(arg_6)
    var_5 = arg_9.__add__(1)
    out_2 = var_4.__truediv__(var_5)
    var_6 = arg_3.__mul__(arg_9)
    var_7 = var_6.__add__(arg_7)
    var_8 = arg_9.__add__(1)
    out_3 = var_7.__truediv__(var_8)
    var_9 = arg_4.__mul__(arg_9)
    var_10 = var_9.__add__(arg_8)
    var_11 = arg_9.__add__(1)
    out_4 = var_10.__truediv__(var_11)
    return out_1, out_2, out_3, out_4

In [6]:
# Test averaging plan
# Pretend there're diffs, all params of which are ones * dummy_coeffs
dummy_coeffs = [1, 5.5, 7, 55]
dummy_diffs = [[th.ones_like(param) * i for param in bandit_args_th] for i in dummy_coeffs]
mean_coeff = th.tensor(dummy_coeffs).mean().item()

# Remove original function to make sure we execute traced Plan
avg_plan.forward = None

# Calculate avg value using our plan
avg = dummy_diffs[0]

for i, diff in enumerate(dummy_diffs[1:]):
    avg = avg_plan(list(avg), diff, th.tensor([i + 1]))

# Avg should be ones*mean_coeff for each param
for i, param in enumerate(bandit_args_th):
    expected = th.ones_like(param) * mean_coeff
    assert avg[i].eq(expected).all(), f"param #{i}"

Step 4: Host in PyGrid

Let's now host everything in PyGrid so that it can be accessed by worker libraries (syft.js, KotlinSyft, SwiftSyft, or even PySyft itself).

First, we need a function to send websocket messages to PyGrid.


In [7]:
async def sendWsMessage(data):
    async with websockets.connect('ws://' + gatewayWsUrl) as websocket:
        await websocket.send(json.dumps(data))
        message = await websocket.recv()
        return json.loads(message)

Follow PyGrid README.md to build openmined/grid-gateway image from the latest dev branch and spin up PyGrid using docker-compose up --build.

Define name, version, configs.


In [8]:
# Default gateway address when running locally 
gatewayWsUrl = "127.0.0.1:5000"
grid = StaticFLClient(id="test", address=gatewayWsUrl, secure=False)
grid.connect()

# These are the name/version you use in worker
name = "bandit"
version = "1.0.0"

client_config = {
    "name": name,
    "version": version,
    "batch_size": 64,
    "lr": 0.005,
    "max_updates": 100  # custom syft.js option that limits number of training loops per worker
}

server_config = {
    "min_workers": 1,
    "max_workers": 1,
    "pool_selection": "random",
    "do_not_reuse_workers_until_cycle": 20,
    "cycle_length": 28800,  # max cycle length in seconds
    "num_cycles": 200,  # max number of cycles
    "max_diffs": 1,  # number of diffs to collect before avg
    "minimum_upload_speed": 0,
    "minimum_download_speed": 0,
    "iterative_plan": True  # tells PyGrid that avg plan is executed per diff
}

In [9]:
model_params_state = State(
    state_placeholders=[
        PlaceHolder().instantiate(param)
        for param in bandit_args_th
    ]
)

response = grid.host_federated_training(
    model=model_params_state,
    client_plans={'training_plan': bandit_thompson},
    client_protocols={},
    server_averaging_plan=avg_plan,
    client_config=client_config,
    server_config=server_config
)

print("Host response:", response)


Host response: {'type': 'model_centric/host-training', 'data': {'status': 'success'}}

Make authentication request:


In [10]:
auth_request = {
    "type": "model_centric/authenticate",
    "data": {
        "model_name": name,
        "model_version": version,
    }
}

auth_response = await sendWsMessage(auth_request)

print('Auth response: ', json.dumps(auth_response, indent=2))


Auth response:  {
  "type": "model_centric/authenticate",
  "data": {
    "status": "success",
    "worker_id": "4fdfaae8-c4b6-4b12-bb59-fa9de471ca54"
  }
}

Make the cycle request:


In [11]:
cycle_request = {
    "type": "model_centric/cycle-request",
    "data": {
        "worker_id": auth_response['data']['worker_id'],
        "model": name,
        "version": version,
        "ping": 1,
        "download": 10,
        "upload": 10,
    }
}

cycle_response = await sendWsMessage(cycle_request)

print('Cycle response:', json.dumps(cycle_response, indent=2))

worker_id = auth_response['data']['worker_id']
request_key = cycle_response['data']['request_key']
model_id = cycle_response['data']['model_id'] 
training_plan_id = cycle_response['data']['plans']['training_plan']


Cycle response: {
  "type": "model_centric/cycle-request",
  "data": {
    "status": "accepted",
    "request_key": "581db59375959ec00518b7b9ad1ec7f98678d5edf7b7724f9509266ec5b60139",
    "version": "1.0.0",
    "model": "bandit",
    "plans": {
      "training_plan": 2
    },
    "protocols": {},
    "client_config": {
      "name": "bandit",
      "version": "1.0.0",
      "batch_size": 64,
      "lr": 0.005,
      "max_updates": 100
    },
    "model_id": 1
  }
}

Step 5: Train

To train hosted model, use the multi-armed bandit example in syft.js.